Лабораторная работа 4: Семантическая сегментация с использованием PyTorch¶
Набор данных.¶
Данные содержат 8 классов. Маска сегментации имеет вид трехканального изображения с пикселями, значения которых равно либо 0, либо 255, например, (0, 0, 0), (0, 0, 255) и так далее. Помимо этого встречаются и промежуточные значения, отличные от 0 и 255. В рамках данной лабораторной работы предлагается следующее преобразование: значения маски, меньшие 128, нужно установить в 0, а значения, равные или больше 128, установить в 255.
Для упрощения работы рекомендуется объединить следующие классы в один:
- класс 2 - Aquatic plants and sea-grass
- класс 3 - Wrecks and ruins
- класс 5 - Reefs and invertebrates
- класс 7 - Sea-floor and rocks
Требования¶
Необходимо выполнить и отобразить в Jupyter следующие задачи:
- Загрузка и проверка данных. Загрузить и предобработать данные с демонстрацией избранных изображений и соответствующих масок, чтобы подтвердить корректность загрузки и соответствие размерностей данных.
- Реализация архитектуры сети. Разработать архитектуру нейронной сети для сегментации с использованием фреймворка PyTorch.
- Настройка гиперпараметров обучения. Настроить параметры обучения, такие как функция ошибки, размер сети, скорость обучения и другие параметры.
- Тестирование модели. После завершения обучения для оценки качества работы необходимо посчитать accuracy, IoU и визуализировать confusion matrix (с нормализацией,
normalize='true'). - Визуализация результатов. После завершения обучения построить и отобразить результаты сегментации на тестовых изображениях, сравнивая с реальными масками сегментации.
Выбор архитектуры:
- Можно использовать или адаптировать известные архитектуры глубокого обучения.
- Может быть полезным:
- уменьшить количество параметров в нейронной сети и размер входного изображения для ускорения сходимости, предотвращения переобучения и ускорения работы нейронной сети.
- использовать аугментацию данных и взвешенные/специализированные функции ошибки. При аугментации данных необходимо учитывать связь изображений с маской классов.
- Использовать перенос знаний недопустимо.
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import confusion_matrix
import seaborn as sea
import matplotlib.pyplot as plt
from torchsummary import summary
from PIL import Image
import numpy as np
import time
Загрузка и проверка корректности данных
import os
import shutil
from google.colab import drive
drive.mount('/content/drive')
def copy_files_recursive(source_folder, destination_folder):
for root, dirs, files in os.walk(source_folder):
for file in files:
source_path = os.path.join(root, file)
destination_path = os.path.join(destination_folder, os.path.relpath(source_path, source_folder))
os.makedirs(os.path.dirname(destination_path), exist_ok=True)
shutil.copyfile(source_path, destination_path)
Mounted at /content/drive
remote_root = '/content/drive/MyDrive/SUIM'
root = '/content/SUIM'
copy_files_recursive(remote_root, root)
number_classes = 5
classes = {
"background": [(0, 0, 0)],
"human_divers": [(0, 0, 1)],
"robots": [(1, 0, 0)],
"fish_vertebrates": [(1, 1, 0)],
"other": [(0, 1, 0), (0, 1, 1), (1, 0, 1), (1, 1, 1)]
}
color_classes = [[0, 0, 0],
[0, 0, 1],
[1, 0, 0],
[1, 1, 0],
[0, 1, 0]]
class CustomDataset(Dataset):
def __init__(self, images, masks):
self.images = torch.tensor(images, dtype = torch.float)
self.masks = torch.tensor(masks, dtype = torch.float)
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
image = self.images[idx]
mask = self.masks[idx]
return image, mask
def load_dataset(root_images, root_masks, image_size):
images = []
list_dir = sorted(os.listdir(root_images))
for file_name in list_dir:
file_path = os.path.join(root_images, file_name)
if os.path.isfile(file_path):
with Image.open(file_path) as image:
resized_image = np.array(image.resize(image_size)) / 255
images.append(resized_image)
labels = []
list_dir = sorted(os.listdir(root_masks))
for file_name in list_dir:
file_path = os.path.join(root_masks, file_name)
if os.path.isfile(file_path):
with Image.open(file_path) as mask:
background = np.zeros(image_size)
human_divers = np.zeros(image_size)
robots = np.zeros(image_size)
fish_vertebrates = np.zeros(image_size)
other = np.zeros(image_size)
resized_mask = np.array(mask.resize(image_size)) / 255
resized_mask = np.where(resized_mask < 0.5, 0, 1)
for i in range(image_size[0]):
for j in range(image_size[1]):
if np.all(resized_mask[i, j] == classes["background"], axis = -1):
background[i, j] = 1
elif np.all(resized_mask[i, j] == classes["human_divers"], axis = -1):
human_divers[i, j] = 1
elif np.all(resized_mask[i, j] == classes["robots"], axis = -1):
robots[i, j] = 1
elif np.all(resized_mask[i, j] == classes["fish_vertebrates"], axis = -1):
fish_vertebrates[i, j] = 1
else:
other[i, j] = 1
labels.append(np.stack([background, human_divers, robots, fish_vertebrates, other], -1))
images = np.array(images)
labels = np.array(labels)
dataset = CustomDataset(images, labels)
return dataset
def dataset_info(dataset):
print("Размер датасета изображений:", dataset.images.shape)
print("Размер датасета масок:", dataset.masks.shape)
print()
number_pixels = {'Background': np.count_nonzero(dataset.masks[:, :, :, 0] == 1),
'Human divers': np.count_nonzero(dataset.masks[:, :, :, 1] == 1),
'Robots': np.count_nonzero(dataset.masks[:, :, :, 2] == 1),
'Fish and vertebrates': np.count_nonzero(dataset.masks[:, :, :, 3] == 1),
'Other': np.count_nonzero(dataset.masks[:, :, :, 4] == 1)}
sum_pixel = dataset.images.shape[0] * dataset.images.shape[1] * dataset.images.shape[2]
for key, value in number_pixels.items():
print(f'Класс: {key}, Число пикселей: {value}({(value / sum_pixel * 100):.2f}%)')
image_size = 80
root_train_images = "/content/SUIM/train_val/images"
root_train_masks = "/content/SUIM/train_val/masks"
train_dataset = load_dataset(root_train_images, root_train_masks, (image_size, image_size))
root_test_images = "/content/SUIM/TEST/images"
root_test_masks = "/content/SUIM/TEST/masks"
test_dataset = load_dataset(root_test_images, root_test_masks, (image_size, image_size))
print('Train dataset:')
dataset_info(train_dataset)
print()
print('Test dataset:')
dataset_info(test_dataset)
Train dataset: Размер датасета изображений: torch.Size([1525, 80, 80, 3]) Размер датасета масок: torch.Size([1525, 80, 80, 5]) Класс: Background, Число пикселей: 3034338(31.09%) Класс: Human divers, Число пикселей: 184114(1.89%) Класс: Robots, Число пикселей: 37740(0.39%) Класс: Fish and vertebrates, Число пикселей: 767257(7.86%) Класс: Other, Число пикселей: 5736551(58.78%) Test dataset: Размер датасета изображений: torch.Size([110, 80, 80, 3]) Размер датасета масок: torch.Size([110, 80, 80, 5]) Класс: Background, Число пикселей: 282598(40.14%) Класс: Human divers, Число пикселей: 20661(2.93%) Класс: Robots, Число пикселей: 4557(0.65%) Класс: Fish and vertebrates, Число пикселей: 54083(7.68%) Класс: Other, Число пикселей: 342101(48.59%)
def plot_images_with_masks(dataset, title):
vert_size = 6
horiz_size = 3
fig, axes = plt.subplots(vert_size, horiz_size * 2, figsize = (15, 15))
fig.suptitle(title)
mask_sizes = (image_size, image_size, 3)
count_images = vert_size * horiz_size
for number in range(count_images):
i = number // horiz_size
j = number % horiz_size
image, mask = dataset[number]
axes[i, j * 2].imshow(image, cmap=plt.cm.binary)
axes[i, j * 2].axis('off')
rgb_mask = np.zeros(mask_sizes)
for k in range(number_classes):
rgb_mask[mask[:, :, k] > 0] = color_classes[k]
axes[i, j * 2 + 1].imshow(image, cmap=plt.cm.binary)
axes[i, j * 2 + 1].imshow(rgb_mask, alpha = 0.35)
axes[i, j * 2 + 1].axis('off')
plt.tight_layout()
plt.show()
plot_images_with_masks(train_dataset, 'Examples from train dataset')
plot_images_with_masks(test_dataset, 'Examples from test dataset')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Device: {device}')
Device: cuda
Эксперименты¶
Гиперпараметры обучения
learning_rate = 0.01
epochs = 50
batch_size = 36
Разбиение датасетов на батчи
training_dataset, validation_dataset = torch.utils.data.random_split(train_dataset, [0.85, 0.15])
training_loader = DataLoader(training_dataset, batch_size = batch_size, shuffle = True)
validation_loader = DataLoader(validation_dataset, batch_size = batch_size, shuffle = True)
test_loader = DataLoader(test_dataset, batch_size = batch_size, shuffle = False)
class DoubleConv(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.double_conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size = 3, padding = 'same'),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
nn.Dropout(0.25)
)
def forward(self, x):
return self.double_conv(x)
class DownSample(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.max_pool = nn.MaxPool2d(kernel_size = 2, stride = 2)
self.conv = DoubleConv(in_channels, out_channels)
def forward(self, x):
x = self.max_pool(x)
return self.conv(x)
class UpSample(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.up = nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = True)
self.conv = DoubleConv(in_channels, out_channels)
def forward(self, x1, x2):
x1 = self.up(x1)
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
diffY // 2, diffY - diffY // 2])
x = torch.cat([x2, x1], dim = 1)
return self.conv(x)
class OutConv(nn.Module):
def __init__(self, in_channels, out_channels):
super(OutConv, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size = 1)
def forward(self, x):
return self.conv(x)
class UNet(nn.Module):
def __init__(self, n_channels = 3, n_classes = 5):
super(UNet, self).__init__()
self.n_channels = n_channels
self.n_classes = n_classes
self.input = DoubleConv(n_channels, 64)
self.down1 = DownSample(64, 128)
self.down2 = DownSample(128, 256)
self.down3 = DownSample(256, 512)
self.down4 = DownSample(512, 1024)
self.up1 = UpSample(1024 + 512, 512)
self.up2 = UpSample(512 + 256, 256)
self.up3 = UpSample(256 + 128, 128)
self.up4 = UpSample(128 + 64, 64)
self.output = OutConv(64, n_classes)
def forward(self, x):
x1 = self.input(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
logits = self.output(x)
return torch.sigmoid(logits)
net = UNet().to(device)
criterion = nn.BCELoss()
optimizer = optim.Adam(net.parameters(), lr = learning_rate)
summary(net, (3, image_size, image_size))
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 80, 80] 1,792
BatchNorm2d-2 [-1, 64, 80, 80] 128
ReLU-3 [-1, 64, 80, 80] 0
Dropout-4 [-1, 64, 80, 80] 0
DoubleConv-5 [-1, 64, 80, 80] 0
MaxPool2d-6 [-1, 64, 40, 40] 0
Conv2d-7 [-1, 128, 40, 40] 73,856
BatchNorm2d-8 [-1, 128, 40, 40] 256
ReLU-9 [-1, 128, 40, 40] 0
Dropout-10 [-1, 128, 40, 40] 0
DoubleConv-11 [-1, 128, 40, 40] 0
DownSample-12 [-1, 128, 40, 40] 0
MaxPool2d-13 [-1, 128, 20, 20] 0
Conv2d-14 [-1, 256, 20, 20] 295,168
BatchNorm2d-15 [-1, 256, 20, 20] 512
ReLU-16 [-1, 256, 20, 20] 0
Dropout-17 [-1, 256, 20, 20] 0
DoubleConv-18 [-1, 256, 20, 20] 0
DownSample-19 [-1, 256, 20, 20] 0
MaxPool2d-20 [-1, 256, 10, 10] 0
Conv2d-21 [-1, 512, 10, 10] 1,180,160
BatchNorm2d-22 [-1, 512, 10, 10] 1,024
ReLU-23 [-1, 512, 10, 10] 0
Dropout-24 [-1, 512, 10, 10] 0
DoubleConv-25 [-1, 512, 10, 10] 0
DownSample-26 [-1, 512, 10, 10] 0
MaxPool2d-27 [-1, 512, 5, 5] 0
Conv2d-28 [-1, 1024, 5, 5] 4,719,616
BatchNorm2d-29 [-1, 1024, 5, 5] 2,048
ReLU-30 [-1, 1024, 5, 5] 0
Dropout-31 [-1, 1024, 5, 5] 0
DoubleConv-32 [-1, 1024, 5, 5] 0
DownSample-33 [-1, 1024, 5, 5] 0
Upsample-34 [-1, 1024, 10, 10] 0
Conv2d-35 [-1, 512, 10, 10] 7,078,400
BatchNorm2d-36 [-1, 512, 10, 10] 1,024
ReLU-37 [-1, 512, 10, 10] 0
Dropout-38 [-1, 512, 10, 10] 0
DoubleConv-39 [-1, 512, 10, 10] 0
UpSample-40 [-1, 512, 10, 10] 0
Upsample-41 [-1, 512, 20, 20] 0
Conv2d-42 [-1, 256, 20, 20] 1,769,728
BatchNorm2d-43 [-1, 256, 20, 20] 512
ReLU-44 [-1, 256, 20, 20] 0
Dropout-45 [-1, 256, 20, 20] 0
DoubleConv-46 [-1, 256, 20, 20] 0
UpSample-47 [-1, 256, 20, 20] 0
Upsample-48 [-1, 256, 40, 40] 0
Conv2d-49 [-1, 128, 40, 40] 442,496
BatchNorm2d-50 [-1, 128, 40, 40] 256
ReLU-51 [-1, 128, 40, 40] 0
Dropout-52 [-1, 128, 40, 40] 0
DoubleConv-53 [-1, 128, 40, 40] 0
UpSample-54 [-1, 128, 40, 40] 0
Upsample-55 [-1, 128, 80, 80] 0
Conv2d-56 [-1, 64, 80, 80] 110,656
BatchNorm2d-57 [-1, 64, 80, 80] 128
ReLU-58 [-1, 64, 80, 80] 0
Dropout-59 [-1, 64, 80, 80] 0
DoubleConv-60 [-1, 64, 80, 80] 0
UpSample-61 [-1, 64, 80, 80] 0
Conv2d-62 [-1, 5, 80, 80] 325
OutConv-63 [-1, 5, 80, 80] 0
================================================================
Total params: 15,678,085
Trainable params: 15,678,085
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.07
Forward/backward pass size (MB): 82.03
Params size (MB): 59.81
Estimated Total Size (MB): 141.91
----------------------------------------------------------------
def train(net, train_loader, validation_loader, criterion, epochs):
for epoch in range(epochs):
loss_list = []
time_one = time.time()
for data in train_loader:
images = data[0].permute(0, 3, 1, 2).to(device)
labels = data[1].permute(0, 3, 1, 2).to(device)
outputs = net(images)
loss = criterion(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
loss_list.append(loss)
diff_time = time.time() - time_one
loss_validation_list = []
time_one = time.time()
for data in validation_loader:
images = data[0].permute(0, 3, 1, 2).to(device)
labels = data[1].permute(0, 3, 1, 2).to(device)
outputs = net(images)
loss = criterion(outputs, labels)
loss_validation_list.append(loss)
diff_time_validation = time.time() - time_one
print(f"Epoch: {epoch + 1}/{epochs}, Train "
f"Loss: {torch.stack(loss_list).mean():.4f}, "
f"Time: {diff_time:.2f} Validation "
f"Loss: {torch.stack(loss_validation_list).mean():.4f}, "
f"Time: {diff_time_validation:.2f}")
train(net, training_loader, validation_loader, criterion, epochs)
Epoch: 1/50, Train Loss: 0.3104, Time: 6.52 Validation Loss: 0.2474, Time: 0.44 Epoch: 2/50, Train Loss: 0.2360, Time: 6.46 Validation Loss: 0.2362, Time: 0.44 Epoch: 3/50, Train Loss: 0.2245, Time: 6.49 Validation Loss: 0.2295, Time: 0.45 Epoch: 4/50, Train Loss: 0.2223, Time: 6.53 Validation Loss: 0.2209, Time: 0.45 Epoch: 5/50, Train Loss: 0.2197, Time: 6.56 Validation Loss: 0.2144, Time: 0.45 Epoch: 6/50, Train Loss: 0.2108, Time: 6.57 Validation Loss: 0.2093, Time: 0.45 Epoch: 7/50, Train Loss: 0.2107, Time: 6.57 Validation Loss: 0.2316, Time: 0.45 Epoch: 8/50, Train Loss: 0.2240, Time: 6.60 Validation Loss: 0.2081, Time: 0.45 Epoch: 9/50, Train Loss: 0.2034, Time: 6.63 Validation Loss: 0.2433, Time: 0.46 Epoch: 10/50, Train Loss: 0.2316, Time: 6.65 Validation Loss: 0.2114, Time: 0.46 Epoch: 11/50, Train Loss: 0.2001, Time: 6.66 Validation Loss: 0.2070, Time: 0.46 Epoch: 12/50, Train Loss: 0.1955, Time: 6.69 Validation Loss: 0.1991, Time: 0.46 Epoch: 13/50, Train Loss: 0.1875, Time: 6.70 Validation Loss: 0.1935, Time: 0.46 Epoch: 14/50, Train Loss: 0.2033, Time: 6.73 Validation Loss: 0.1912, Time: 0.47 Epoch: 15/50, Train Loss: 0.1925, Time: 6.74 Validation Loss: 0.1894, Time: 0.47 Epoch: 16/50, Train Loss: 0.1870, Time: 6.75 Validation Loss: 0.1891, Time: 0.46 Epoch: 17/50, Train Loss: 0.1854, Time: 6.76 Validation Loss: 0.1782, Time: 0.47 Epoch: 18/50, Train Loss: 0.1903, Time: 6.77 Validation Loss: 0.1854, Time: 0.47 Epoch: 19/50, Train Loss: 0.1804, Time: 6.78 Validation Loss: 0.1840, Time: 0.47 Epoch: 20/50, Train Loss: 0.1779, Time: 6.79 Validation Loss: 0.1780, Time: 0.47 Epoch: 21/50, Train Loss: 0.1723, Time: 6.81 Validation Loss: 0.1886, Time: 0.47 Epoch: 22/50, Train Loss: 0.1943, Time: 6.81 Validation Loss: 0.1729, Time: 0.47 Epoch: 23/50, Train Loss: 0.1789, Time: 6.82 Validation Loss: 0.1786, Time: 0.47 Epoch: 24/50, Train Loss: 0.1766, Time: 6.83 Validation Loss: 0.1799, Time: 0.47 Epoch: 25/50, Train Loss: 0.1621, Time: 6.84 Validation Loss: 0.1611, Time: 0.47 Epoch: 26/50, Train Loss: 0.1564, Time: 6.84 Validation Loss: 0.1699, Time: 0.47 Epoch: 27/50, Train Loss: 0.1520, Time: 6.85 Validation Loss: 0.1851, Time: 0.47 Epoch: 28/50, Train Loss: 0.1576, Time: 6.85 Validation Loss: 0.1693, Time: 0.47 Epoch: 29/50, Train Loss: 0.1522, Time: 6.87 Validation Loss: 0.1595, Time: 0.48 Epoch: 30/50, Train Loss: 0.1737, Time: 6.87 Validation Loss: 0.1658, Time: 0.48 Epoch: 31/50, Train Loss: 0.1665, Time: 6.87 Validation Loss: 0.1772, Time: 0.48 Epoch: 32/50, Train Loss: 0.1901, Time: 6.87 Validation Loss: 0.1819, Time: 0.48 Epoch: 33/50, Train Loss: 0.1714, Time: 6.88 Validation Loss: 0.1764, Time: 0.48 Epoch: 34/50, Train Loss: 0.1575, Time: 6.89 Validation Loss: 0.1655, Time: 0.47 Epoch: 35/50, Train Loss: 0.1525, Time: 6.89 Validation Loss: 0.1816, Time: 0.48 Epoch: 36/50, Train Loss: 0.1832, Time: 6.89 Validation Loss: 0.1725, Time: 0.48 Epoch: 37/50, Train Loss: 0.1644, Time: 6.89 Validation Loss: 0.1715, Time: 0.48 Epoch: 38/50, Train Loss: 0.1535, Time: 6.90 Validation Loss: 0.1592, Time: 0.48 Epoch: 39/50, Train Loss: 0.1496, Time: 6.89 Validation Loss: 0.1693, Time: 0.48 Epoch: 40/50, Train Loss: 0.1474, Time: 6.90 Validation Loss: 0.1576, Time: 0.48 Epoch: 41/50, Train Loss: 0.1414, Time: 6.90 Validation Loss: 0.1567, Time: 0.48 Epoch: 42/50, Train Loss: 0.1452, Time: 6.89 Validation Loss: 0.1592, Time: 0.48 Epoch: 43/50, Train Loss: 0.1335, Time: 6.89 Validation Loss: 0.1553, Time: 0.48 Epoch: 44/50, Train Loss: 0.1347, Time: 6.89 Validation Loss: 0.1624, Time: 0.48 Epoch: 45/50, Train Loss: 0.1383, Time: 6.90 Validation Loss: 0.1554, Time: 0.48 Epoch: 46/50, Train Loss: 0.1467, Time: 6.89 Validation Loss: 0.1548, Time: 0.48 Epoch: 47/50, Train Loss: 0.1238, Time: 6.91 Validation Loss: 0.1572, Time: 0.48 Epoch: 48/50, Train Loss: 0.1177, Time: 6.90 Validation Loss: 0.1543, Time: 0.48 Epoch: 49/50, Train Loss: 0.1138, Time: 6.90 Validation Loss: 0.1583, Time: 0.48 Epoch: 50/50, Train Loss: 0.1077, Time: 6.88 Validation Loss: 0.1564, Time: 0.48
def IoU(labels, predict):
intersection = np.logical_and(labels, predict)
union = np.logical_or(labels, predict)
if np.sum(union) == 0:
iou_score = 0
else:
iou_score = np.sum(intersection) / np.sum(union)
return iou_score
def metrics_compute(net, data_loader):
accuracy_list, IoU_list = [], []
cm_mask, cm_predict_mask = [], []
with torch.no_grad():
for images, masks in data_loader:
images = images.permute(0, 3, 1, 2).to(device)
masks = masks.permute(0, 3, 1, 2).numpy()
predict_masks = net(images)
predict_masks = torch.where(predict_masks < torch.tensor(0.5), torch.tensor(0), torch.tensor(1)).cpu().numpy()
for k in range(5):
masks_tmp = masks[:, k, :, :]
predict_masks_tmp = predict_masks[:, k, :, :]
masks_tmp = np.where(masks_tmp != 1, 0, k + 1)
predict_masks_tmp = np.where(predict_masks_tmp != 1, 0, k + 1)
masks[:, k, :, :] = masks_tmp
predict_masks[:, k, :, :] = predict_masks_tmp
cm_mask.append(masks)
cm_predict_mask.append(predict_masks)
temp_accucary, temp_iou = [], []
for k in range(5):
accuracy = np.mean(predict_masks[:, k, :, :] == masks[:, k, :, :])
iou = IoU(masks[:, k, :, :], predict_masks[:, k, :, :])
temp_accucary.append(accuracy)
temp_iou.append(iou)
accuracy_list.append(temp_accucary)
IoU_list.append(temp_iou)
accuracy_list = np.array(accuracy_list)
IoU_list = np.array(IoU_list)
print(f'Оценка Accuracy для каждого класса: {np.mean(accuracy_list, axis = 0, dtype = np.float16)}')
print(f'Оценка IoU для каждого класса: {np.mean(IoU_list, axis = 0, dtype = np.float16)}')
print(f'Оценка Accuracy на данных: {np.mean(accuracy_list, dtype = np.float16):.4f}')
print(f'Оценка IoU на данных: {np.mean(IoU_list, dtype = np.float16):.4f}')
cm_mask = np.concatenate(cm_mask, axis = 0)
cm_predict_mask = np.concatenate(cm_predict_mask, axis = 0)
cm_mask = cm_mask.flatten()
cm_predict_mask = cm_predict_mask.flatten()
name_class = ['Background', 'Human divers', 'Robots', 'Fish and vertebrates', 'Other']
cm = confusion_matrix(cm_mask, cm_predict_mask, labels = np.arange(5), normalize = 'true')
sea.heatmap(cm, annot = True, cmap = 'Blues', xticklabels = name_class, yticklabels = name_class)
plt.xlabel('Предсказанные классы')
plt.ylabel('Истинные классы')
plt.title('Confusion matrix')
plt.show()
metrics_compute(net, test_loader)
Оценка Accuracy для каждого класса: [0.911 0.979 0.995 0.9478 0.875 ] Оценка IoU для каждого класса: [0.792 0.1936 0.0718 0.2764 0.786 ] Оценка Accuracy на данных: 0.9414 Оценка IoU на данных: 0.4243
def plot_images_with_masks_test(net, dataset):
vert_size = 12
horiz_size = 2
fig, axes = plt.subplots(vert_size, horiz_size * 3, figsize = (15, 25))
fig.suptitle("Predicted vs. True")
mask_sizes = (image_size, image_size, 3)
count_images = vert_size * horiz_size
for number in range(count_images):
i = number // horiz_size
j = number % horiz_size
image, mask = dataset[number]
axes[i, j * 3].imshow(image, cmap=plt.cm.binary)
axes[i, j * 3].set_title('Image', fontsize = 10)
axes[i, j * 3].axis('off')
with torch.no_grad():
images = image
images = images.unsqueeze(0)
images = images.permute(0, 3, 1, 2).to(device)
predict_mask = net(images)
predict_mask = torch.where(predict_mask < torch.tensor(0.5), torch.tensor(0), torch.tensor(1)).permute(0, 2, 3, 1).cpu()
rgb_predicted_mask = np.zeros(mask_sizes)
for k in range(number_classes):
rgb_predicted_mask[predict_mask[0, :, :, k] > 0] = color_classes[k]
axes[i, j * 3 + 1].imshow(image, cmap=plt.cm.binary)
axes[i, j * 3 + 1].imshow(rgb_predicted_mask, alpha = 0.35)
axes[i, j * 3 + 1].set_title('Predicted', fontsize = 10)
axes[i, j * 3 + 1].axis('off')
rgb_mask = np.zeros(mask_sizes)
for k in range(number_classes):
rgb_mask[mask[:, :, k] > 0] = color_classes[k]
axes[i, j * 3 + 2].imshow(image, cmap=plt.cm.binary)
axes[i, j * 3 + 2].imshow(rgb_mask, alpha = 0.35)
axes[i, j * 3 + 2].set_title('True', fontsize = 10)
axes[i, j * 3 + 2].axis('off')
plt.tight_layout()
plt.show()
plot_images_with_masks_test(net, test_dataset)
Аугментация данных и перебалансировка классов
В тренировочных данных у нас большой перекос данных в сторону классов "Background" и "Other". Попробуем использовать аугментацию для увеличения количества пикселей маленьких классов и уберем картинки, на которых большая часть это "Background" и "Other"
def data_rebalancing(dataset, coeff):
new_images = []
new_masks = []
mask_sizes = (image_size, image_size, 3)
count_pixels = image_size * image_size
for image, mask in dataset:
rgb_mask = np.zeros(mask_sizes)
for k in range(number_classes):
rgb_mask[mask[:, :, k] > 0] = color_classes[k]
count_colors = [0, 0, 0, 0, 0]
for i in range(image_size):
for j in range(image_size):
for k in range(number_classes):
if np.all(rgb_mask[i, j] == color_classes[k], axis=-1):
count_colors[k] += 1
back_other_colors = count_colors[0] + count_colors[4]
small_classes_colors = count_pixels - back_other_colors
if back_other_colors / count_pixels < 1 - coeff:
print('Количество пикселей на класс: ', count_colors)
new_images.append(image)
new_masks.append(mask)
down_image = np.flipud(image)
down_mask = np.flipud(mask)
new_images.append(down_image)
new_masks.append(down_mask)
right_image = np.fliplr(image)
right_mask = np.fliplr(mask)
new_images.append(right_image)
new_masks.append(right_mask)
right_down_image = np.flipud(right_image)
right_down_mask = np.flipud(right_mask)
new_images.append(right_down_image)
new_masks.append(right_down_mask)
new_images = np.array(new_images)
new_masks = np.array(new_masks)
new_dataset = CustomDataset(new_images, new_masks)
return new_dataset
new_train_dataset = data_rebalancing(train_dataset, 0.2)
Количество пикселей на класс: [4856, 690, 0, 764, 90] Количество пикселей на класс: [1290, 1321, 1371, 0, 2418] Количество пикселей на класс: [4886, 463, 1051, 0, 0] Количество пикселей на класс: [1576, 1815, 385, 0, 2624] Количество пикселей на класс: [1451, 146, 0, 1502, 3301] Количество пикселей на класс: [2011, 1674, 0, 1769, 946] Количество пикселей на класс: [4962, 1081, 357, 0, 0] Количество пикселей на класс: [5011, 403, 986, 0, 0] Количество пикселей на класс: [1295, 1459, 0, 2067, 1579] Количество пикселей на класс: [1543, 2165, 0, 0, 2692] Количество пикселей на класс: [4010, 0, 2390, 0, 0] Количество пикселей на класс: [2047, 1172, 127, 0, 3054] Количество пикселей на класс: [5016, 1384, 0, 0, 0] Количество пикселей на класс: [2446, 1694, 0, 0, 2260] Количество пикселей на класс: [4698, 1577, 124, 0, 1] Количество пикселей на класс: [1394, 1352, 0, 2578, 1076] Количество пикселей на класс: [2520, 1493, 0, 0, 2387] Количество пикселей на класс: [2620, 2186, 445, 0, 1149] Количество пикселей на класс: [5085, 508, 159, 648, 0] Количество пикселей на класс: [2467, 1828, 0, 0, 2105] Количество пикселей на класс: [4972, 749, 679, 0, 0] Количество пикселей на класс: [4856, 907, 637, 0, 0] Количество пикселей на класс: [4410, 1615, 375, 0, 0] Количество пикселей на класс: [1397, 1452, 0, 983, 2568] Количество пикселей на класс: [2498, 1366, 1, 0, 2535] Количество пикселей на класс: [2209, 1101, 337, 0, 2753] Количество пикселей на класс: [1091, 1296, 532, 0, 3481] Количество пикселей на класс: [1984, 1103, 266, 0, 3047] Количество пикселей на класс: [0, 1385, 0, 0, 5015] Количество пикселей на класс: [1497, 1460, 0, 1944, 1499] Количество пикселей на класс: [2774, 617, 73, 1634, 1302] Количество пикселей на класс: [5076, 847, 323, 154, 0] Количество пикселей на класс: [1917, 2149, 0, 0, 2334] Количество пикселей на класс: [2761, 667, 32, 1283, 1657] Количество пикселей на класс: [1583, 1649, 0, 1, 3167] Количество пикселей на класс: [656, 0, 0, 3583, 2161] Количество пикселей на класс: [2621, 0, 1, 1636, 2142] Количество пикселей на класс: [750, 0, 0, 1767, 3883] Количество пикселей на класс: [0, 0, 3, 1732, 4665] Количество пикселей на класс: [1288, 0, 5, 2769, 2338] Количество пикселей на класс: [1192, 1, 4, 1897, 3306] Количество пикселей на класс: [0, 0, 1, 4332, 2067] Количество пикселей на класс: [0, 0, 0, 1560, 4840] Количество пикселей на класс: [1590, 0, 0, 3469, 1341] Количество пикселей на класс: [4608, 0, 0, 1420, 372] Количество пикселей на класс: [270, 0, 5, 3212, 2913] Количество пикселей на класс: [2621, 0, 0, 3376, 403] Количество пикселей на класс: [2354, 0, 1, 1964, 2081] Количество пикселей на класс: [4690, 0, 0, 1710, 0] Количество пикселей на класс: [776, 0, 0, 5451, 173] Количество пикселей на класс: [0, 0, 1, 1614, 4785] Количество пикселей на класс: [4002, 0, 0, 1727, 671] Количество пикселей на класс: [2650, 0, 13, 2130, 1607] Количество пикселей на класс: [435, 0, 2, 1290, 4673] Количество пикселей на класс: [932, 0, 4, 1450, 4014] Количество пикселей на класс: [404, 0, 5, 3041, 2950] Количество пикселей на класс: [3609, 0, 0, 2143, 648] Количество пикселей на класс: [2903, 0, 0, 2942, 555] Количество пикселей на класс: [720, 0, 4, 4113, 1563] Количество пикселей на класс: [2839, 0, 0, 1732, 1829] Количество пикселей на класс: [405, 0, 1, 4187, 1807] Количество пикселей на класс: [2895, 62, 5, 1624, 1814] Количество пикселей на класс: [3045, 0, 6, 1433, 1916] Количество пикселей на класс: [3142, 0, 0, 1434, 1824] Количество пикселей на класс: [0, 0, 3, 1702, 4695] Количество пикселей на класс: [3078, 0, 0, 1886, 1436] Количество пикселей на класс: [0, 0, 1, 1308, 5091] Количество пикселей на класс: [2298, 1, 3, 1564, 2534] Количество пикселей на класс: [395, 0, 2, 1541, 4462] Количество пикселей на класс: [2802, 0, 0, 1577, 2021] Количество пикселей на класс: [1404, 0, 1, 1865, 3130] Количество пикселей на класс: [3004, 1, 0, 2329, 1066] Количество пикселей на класс: [4851, 0, 0, 1549, 0] Количество пикселей на класс: [432, 0, 1, 2364, 3603] Количество пикселей на класс: [1388, 0, 16, 2875, 2121] Количество пикселей на класс: [0, 0, 0, 1430, 4970] Количество пикселей на класс: [0, 0, 0, 1533, 4867] Количество пикселей на класс: [0, 0, 0, 2063, 4337] Количество пикселей на класс: [0, 0, 1, 1959, 4440] Количество пикселей на класс: [0, 0, 0, 1656, 4744] Количество пикселей на класс: [3067, 1, 5, 2488, 839] Количество пикселей на класс: [4580, 0, 0, 1820, 0] Количество пикселей на класс: [1851, 4, 0, 2502, 2043] Количество пикселей на класс: [0, 0, 2, 2604, 3794] Количество пикселей на класс: [0, 0, 0, 1658, 4742] Количество пикселей на класс: [1061, 0, 10, 2821, 2508] Количество пикселей на класс: [131, 0, 0, 2636, 3633] Количество пикселей на класс: [3971, 0, 0, 1490, 939] Количество пикселей на класс: [4968, 0, 0, 1432, 0] Количество пикселей на класс: [1688, 0, 0, 2404, 2308] Количество пикселей на класс: [2749, 0, 11, 1995, 1645] Количество пикселей на класс: [0, 0, 0, 2152, 4248] Количество пикселей на класс: [3271, 2, 0, 1727, 1400] Количество пикселей на класс: [2408, 0, 4, 2246, 1742] Количество пикселей на класс: [1658, 0, 0, 1318, 3424] Количество пикселей на класс: [0, 0, 1, 1928, 4471] Количество пикселей на класс: [1048, 0, 18, 1274, 4060] Количество пикселей на класс: [2474, 0, 0, 2235, 1691] Количество пикселей на класс: [2453, 0, 0, 1815, 2132] Количество пикселей на класс: [4869, 0, 0, 1531, 0] Количество пикселей на класс: [0, 0, 0, 1476, 4924] Количество пикселей на класс: [4602, 0, 0, 1798, 0] Количество пикселей на класс: [668, 0, 0, 1865, 3867] Количество пикселей на класс: [0, 0, 0, 1289, 5111] Количество пикселей на класс: [4025, 0, 0, 1766, 609] Количество пикселей на класс: [0, 0, 3, 1554, 4843] Количество пикселей на класс: [650, 0, 21, 3817, 1912] Количество пикселей на класс: [0, 0, 0, 1406, 4994] Количество пикселей на класс: [1035, 1, 3, 3154, 2207] Количество пикселей на класс: [1013, 0, 4, 1285, 4098] Количество пикселей на класс: [3577, 0, 0, 2141, 682] Количество пикселей на класс: [0, 0, 0, 1420, 4980] Количество пикселей на класс: [896, 0, 0, 3546, 1958] Количество пикселей на класс: [2049, 0, 5, 1382, 2964] Количество пикселей на класс: [0, 0, 0, 2029, 4371] Количество пикселей на класс: [0, 0, 0, 1804, 4596] Количество пикселей на класс: [2822, 0, 6, 1897, 1675] Количество пикселей на класс: [0, 0, 0, 2898, 3502] Количество пикселей на класс: [1369, 0, 2, 1448, 3581] Количество пикселей на класс: [2147, 0, 3, 3609, 641] Количество пикселей на класс: [0, 0, 3, 1622, 4775] Количество пикселей на класс: [80, 0, 0, 1659, 4661] Количество пикселей на класс: [0, 0, 2, 1795, 4603] Количество пикселей на класс: [0, 0, 0, 1298, 5102] Количество пикселей на класс: [0, 0, 2, 2104, 4294] Количество пикселей на класс: [1748, 0, 5, 1926, 2721] Количество пикселей на класс: [0, 0, 1, 2683, 3716] Количество пикселей на класс: [5041, 0, 0, 1359, 0] Количество пикселей на класс: [830, 0, 3, 1727, 3840] Количество пикселей на класс: [0, 0, 1, 2048, 4351] Количество пикселей на класс: [2443, 1, 4, 1535, 2417] Количество пикселей на класс: [4783, 0, 0, 1617, 0] Количество пикселей на класс: [497, 1, 3, 2782, 3117] Количество пикселей на класс: [3262, 1, 4, 1459, 1674] Количество пикселей на класс: [1336, 0, 6, 2786, 2272] Количество пикселей на класс: [0, 0, 1, 1495, 4904] Количество пикселей на класс: [0, 0, 1, 1782, 4617] Количество пикселей на класс: [2204, 0, 2, 2222, 1972] Количество пикселей на класс: [12, 0, 4, 2218, 4166] Количество пикселей на класс: [4986, 0, 0, 1414, 0] Количество пикселей на класс: [3355, 0, 10, 1350, 1685] Количество пикселей на класс: [0, 0, 1, 1357, 5042] Количество пикселей на класс: [2354, 0, 0, 1503, 2543] Количество пикселей на класс: [992, 0, 0, 1325, 4083] Количество пикселей на класс: [4932, 0, 0, 1468, 0] Количество пикселей на класс: [52, 0, 4, 2861, 3483] Количество пикселей на класс: [2489, 0, 13, 1838, 2060] Количество пикселей на класс: [0, 1585, 258, 0, 4557] Количество пикселей на класс: [1404, 0, 5, 3767, 1224] Количество пикселей на класс: [2751, 0, 0, 3649, 0] Количество пикселей на класс: [1772, 1, 1, 2791, 1835] Количество пикселей на класс: [0, 0, 2, 3043, 3355] Количество пикселей на класс: [1243, 0, 5, 2318, 2834] Количество пикселей на класс: [1425, 0, 2, 3150, 1823] Количество пикселей на класс: [0, 0, 0, 2203, 4197] Количество пикселей на класс: [3591, 0, 0, 2106, 703] Количество пикселей на класс: [0, 0, 3, 1530, 4867] Количество пикселей на класс: [0, 0, 8, 3853, 2539] Количество пикселей на класс: [2330, 60, 3, 2093, 1914] Количество пикселей на класс: [0, 0, 0, 1362, 5038] Количество пикселей на класс: [4402, 0, 0, 1998, 0] Количество пикселей на класс: [0, 0, 3, 1679, 4718] Количество пикселей на класс: [2178, 80, 6, 2605, 1531] Количество пикселей на класс: [87, 1, 77, 1765, 4470] Количество пикселей на класс: [29, 0, 38, 1668, 4665] Количество пикселей на класс: [1283, 0, 0, 5117, 0] Количество пикселей на класс: [62, 0, 0, 2176, 4162] Количество пикселей на класс: [0, 0, 0, 1661, 4739] Количество пикселей на класс: [1226, 0, 1, 3125, 2048] Количество пикселей на класс: [1830, 0, 0, 2147, 2423] Количество пикселей на класс: [2114, 0, 6, 1758, 2522] Количество пикселей на класс: [4456, 0, 0, 1944, 0] Количество пикселей на класс: [1497, 0, 23, 2296, 2584] Количество пикселей на класс: [709, 2909, 3, 637, 2142] Количество пикселей на класс: [1066, 0, 50, 1597, 3687] Количество пикселей на класс: [541, 1, 50, 1425, 4383] Количество пикселей на класс: [622, 1, 20, 1721, 4036] Количество пикселей на класс: [0, 0, 2, 1876, 4522] Количество пикселей на класс: [7, 0, 22, 1516, 4855] Количество пикселей на класс: [2865, 0, 1, 1601, 1933] Количество пикселей на класс: [2515, 52, 0, 2485, 1348] Количество пикселей на класс: [1549, 0, 8, 1630, 3213] Количество пикселей на класс: [0, 0, 1, 1308, 5091] Количество пикселей на класс: [4371, 0, 3, 1285, 741] Количество пикселей на класс: [0, 0, 0, 1725, 4675] Количество пикселей на класс: [4077, 0, 0, 2323, 0] Количество пикселей на класс: [1561, 297, 3, 2863, 1676] Количество пикселей на класс: [1703, 0, 11, 1957, 2729] Количество пикселей на класс: [0, 0, 7, 2341, 4052] Количество пикселей на класс: [2680, 0, 3, 2061, 1656] Количество пикселей на класс: [3275, 0, 5, 2583, 537] Количество пикселей на класс: [2117, 0, 5, 2638, 1640] Количество пикселей на класс: [0, 0, 7, 1390, 5003] Количество пикселей на класс: [0, 0, 2, 1965, 4433] Количество пикселей на класс: [4764, 0, 0, 1636, 0] Количество пикселей на класс: [383, 0, 0, 1328, 4689] Количество пикселей на класс: [0, 0, 3, 3670, 2727] Количество пикселей на класс: [4026, 0, 3, 1608, 763] Количество пикселей на класс: [4137, 0, 0, 2263, 0] Количество пикселей на класс: [0, 0, 4, 1412, 4984] Количество пикселей на класс: [3347, 0, 0, 1295, 1758] Количество пикселей на класс: [4860, 0, 0, 1540, 0] Количество пикселей на класс: [4021, 0, 0, 2379, 0] Количество пикселей на класс: [0, 0, 0, 2423, 3977] Количество пикселей на класс: [1616, 0, 25, 1256, 3503] Количество пикселей на класс: [743, 0, 0, 2654, 3003] Количество пикселей на класс: [1585, 0, 17, 1292, 3506] Количество пикселей на класс: [3975, 0, 0, 2425, 0] Количество пикселей на класс: [2152, 1, 8, 1613, 2626] Количество пикселей на класс: [1362, 0, 5, 4385, 648] Количество пикселей на класс: [842, 2, 25, 2003, 3528] Количество пикселей на класс: [4855, 0, 0, 1545, 0] Количество пикселей на класс: [870, 3, 8, 3975, 1544] Количество пикселей на класс: [1564, 0, 1, 3415, 1420] Количество пикселей на класс: [4338, 0, 1, 1541, 520] Количество пикселей на класс: [0, 0, 0, 1830, 4570] Количество пикселей на класс: [4857, 0, 0, 1543, 0] Количество пикселей на класс: [4533, 1553, 314, 0, 0] Количество пикселей на класс: [3560, 2037, 803, 0, 0] Количество пикселей на класс: [0, 0, 9, 2423, 3968] Количество пикселей на класс: [1598, 1661, 0, 0, 3141] Количество пикселей на класс: [0, 0, 0, 3006, 3394] Количество пикселей на класс: [0, 0, 0, 1636, 4764] Количество пикселей на класс: [2104, 1424, 0, 0, 2872] Количество пикселей на класс: [4429, 0, 0, 1971, 0] Количество пикселей на класс: [5035, 1365, 0, 0, 0] Количество пикселей на класс: [270, 0, 0, 1313, 4817] Количество пикселей на класс: [0, 0, 0, 1942, 4458] Количество пикселей на класс: [4463, 1937, 0, 0, 0] Количество пикселей на класс: [4770, 1630, 0, 0, 0] Количество пикселей на класс: [3976, 0, 0, 2424, 0] Количество пикселей на класс: [0, 0, 0, 2877, 3523] Количество пикселей на класс: [0, 0, 0, 2172, 4228] Количество пикселей на класс: [2155, 72, 1, 2658, 1514] Количество пикселей на класс: [0, 0, 1, 2262, 4137] Количество пикселей на класс: [1089, 0, 0, 1342, 3969] Количество пикселей на класс: [1623, 30, 0, 2738, 2009] Количество пикселей на класс: [1593, 1485, 0, 0, 3322] Количество пикселей на класс: [2103, 1711, 0, 0, 2586]
print('New train dataset:')
dataset_info(new_train_dataset)
New train dataset: Размер датасета изображений: torch.Size([956, 80, 80, 3]) Размер датасета масок: torch.Size([956, 80, 80, 5]) Класс: Background, Число пикселей: 1762760(28.81%) Класс: Human divers, Число пикселей: 250992(4.10%) Класс: Robots, Число пикселей: 51116(0.84%) Класс: Fish and vertebrates, Число пикселей: 1698740(27.76%) Класс: Other, Число пикселей: 2354792(38.49%)
Настройка гиперпараметров
learning_rate = 0.01
epochs = 30
batch_size = 36
new_training_dataset, new_validation_dataset = torch.utils.data.random_split(new_train_dataset, [0.85, 0.15])
new_training_loader = DataLoader(new_training_dataset, batch_size = batch_size, shuffle = True)
new_validation_loader = DataLoader(new_validation_dataset, batch_size = batch_size, shuffle = True)
class CustomBCELoss(nn.Module):
def __init__(self, weight):
super(CustomBCELoss, self).__init__()
self.weight = weight
def forward(self, input, target):
loss = []
for k in range(number_classes):
new_input = torch.where(torch.where(input[:, k, :, :] < 0.0001, 0.0001, input[:, k, :, :]) > 0.9999, 0.9999, input[:, k, :, :])
loss.append(torch.where(target[:, k, :, :] == 1, -self.weight[k] * torch.log(new_input), -self.weight[k] * torch.log(1 - new_input)))
return torch.mean(torch.cat(loss, dim=0))
new_net = UNet().to(device)
# weight = [0.4, 0.6, 0.9, 0.7, 0.3]
weight = [1.0, 1.4, 1.3, 1.0, 1.0]
criterion = CustomBCELoss(weight)
optimizer = optim.Adam(new_net.parameters(), lr = learning_rate)
train(new_net, new_training_loader, new_validation_loader, criterion, epochs)
Epoch: 1/30, Train Loss: 0.4680, Time: 4.25 Validation Loss: 0.3826, Time: 0.34 Epoch: 2/30, Train Loss: 0.3456, Time: 4.23 Validation Loss: 0.3274, Time: 0.34 Epoch: 3/30, Train Loss: 0.3188, Time: 4.26 Validation Loss: 0.3179, Time: 0.35 Epoch: 4/30, Train Loss: 0.3031, Time: 4.28 Validation Loss: 0.3216, Time: 0.34 Epoch: 5/30, Train Loss: 0.2952, Time: 4.32 Validation Loss: 0.2988, Time: 0.35 Epoch: 6/30, Train Loss: 0.2849, Time: 4.34 Validation Loss: 0.2905, Time: 0.35 Epoch: 7/30, Train Loss: 0.2792, Time: 4.37 Validation Loss: 0.2769, Time: 0.35 Epoch: 8/30, Train Loss: 0.2692, Time: 4.38 Validation Loss: 0.2754, Time: 0.35 Epoch: 9/30, Train Loss: 0.2583, Time: 4.34 Validation Loss: 0.2731, Time: 0.35 Epoch: 10/30, Train Loss: 0.2616, Time: 4.31 Validation Loss: 0.2603, Time: 0.35 Epoch: 11/30, Train Loss: 0.2475, Time: 4.30 Validation Loss: 0.2644, Time: 0.34 Epoch: 12/30, Train Loss: 0.2562, Time: 4.26 Validation Loss: 0.2697, Time: 0.34 Epoch: 13/30, Train Loss: 0.2407, Time: 4.25 Validation Loss: 0.2548, Time: 0.35 Epoch: 14/30, Train Loss: 0.2363, Time: 4.24 Validation Loss: 0.2513, Time: 0.34 Epoch: 15/30, Train Loss: 0.2275, Time: 4.23 Validation Loss: 0.2381, Time: 0.34 Epoch: 16/30, Train Loss: 0.2153, Time: 4.24 Validation Loss: 0.2432, Time: 0.34 Epoch: 17/30, Train Loss: 0.2260, Time: 4.24 Validation Loss: 0.2353, Time: 0.34 Epoch: 18/30, Train Loss: 0.2176, Time: 4.24 Validation Loss: 0.2334, Time: 0.34 Epoch: 19/30, Train Loss: 0.2114, Time: 4.25 Validation Loss: 0.2328, Time: 0.34 Epoch: 20/30, Train Loss: 0.2106, Time: 4.27 Validation Loss: 0.2219, Time: 0.34 Epoch: 21/30, Train Loss: 0.1961, Time: 4.28 Validation Loss: 0.2284, Time: 0.34 Epoch: 22/30, Train Loss: 0.1870, Time: 4.28 Validation Loss: 0.2033, Time: 0.35 Epoch: 23/30, Train Loss: 0.1755, Time: 4.29 Validation Loss: 0.2141, Time: 0.34 Epoch: 24/30, Train Loss: 0.1719, Time: 4.28 Validation Loss: 0.2203, Time: 0.34 Epoch: 25/30, Train Loss: 0.1911, Time: 4.28 Validation Loss: 0.2060, Time: 0.34 Epoch: 26/30, Train Loss: 0.1750, Time: 4.28 Validation Loss: 0.2038, Time: 0.34 Epoch: 27/30, Train Loss: 0.1705, Time: 4.28 Validation Loss: 0.2067, Time: 0.34 Epoch: 28/30, Train Loss: 0.1750, Time: 4.27 Validation Loss: 0.2129, Time: 0.34 Epoch: 29/30, Train Loss: 0.1587, Time: 4.26 Validation Loss: 0.1959, Time: 0.35 Epoch: 30/30, Train Loss: 0.1536, Time: 4.27 Validation Loss: 0.1983, Time: 0.35
metrics_compute(new_net, test_loader)
Оценка Accuracy для каждого класса: [0.832 0.967 0.9937 0.8057 0.7236] Оценка IoU для каждого класса: [0.621 0.1132 0.11914 0.1781 0.5303 ] Оценка Accuracy на данных: 0.8647 Оценка IoU на данных: 0.3123
plot_images_with_masks_test(new_net, test_dataset)